{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aec0fde7",
   "metadata": {},
   "source": [
    "# T0. trainer 和 evaluator 的基本使用\n",
    "\n",
    "&emsp; 1 &ensp; trainer 和 evaluator 的基本关系\n",
    " \n",
    "&emsp; &emsp; 1.1 &ensp; trainer 和 evaluater 的初始化\n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; driver 的含义与使用要求\n",
    "\n",
    "&emsp; &emsp; 1.3 &ensp; trainer 内部初始化 evaluater\n",
    "\n",
    "&emsp; 2 &ensp; 使用 fastNLP 搭建 argmax 模型\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; trainer_step 和 evaluator_step\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; trainer 和 evaluator 的参数匹配\n",
    "\n",
    "&emsp; &emsp; 2.3 &ensp; 示例：argmax 模型的搭建\n",
    "\n",
    "&emsp; 3 &ensp; 使用 fastNLP 训练 argmax 模型\n",
    " \n",
    "&emsp; &emsp; 3.1 &ensp; trainer 外部初始化的 evaluator\n",
    "\n",
    "&emsp; &emsp; 3.2 &ensp; trainer 内部初始化的 evaluator "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ea669a",
   "metadata": {},
   "source": [
    "## 1. trainer 和 evaluator 的基本关系\n",
    "\n",
    "### 1.1  trainer 和 evaluator 的初始化\n",
    "\n",
    "在`fastNLP 0.8`中，**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
    "\n",
    "&emsp; 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块，其定义方法如下所示\n",
    "\n",
    "在`fastNLP 0.8`中，需要注意，在同个`python`脚本中先使用`Trainer`训练，然后使用`Evaluator`评测\n",
    "\n",
    "&emsp; 非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题：什么是 `driver`？\n",
    "\n",
    "\n",
    "```python\n",
    "trainer = Trainer(\n",
    "        model=model,                        # 模型基于 torch.nn.Module\n",
    "        train_dataloader=train_dataloader,  # 加载模块基于 torch.utils.data.DataLoader  \n",
    "        optimizers=optimizer,               # 优化模块基于 torch.optim.*\n",
    "        ...\n",
    "        driver=\"torch\",                     # 使用 pytorch 模块进行训练 \n",
    "        device='cuda',                      # 使用 GPU：0 显卡执行训练\n",
    "        ...\n",
    "    )\n",
    "...\n",
    "evaluator = Evaluator(\n",
    "        model=model,                        # 模型基于 torch.nn.Module\n",
    "        dataloaders=evaluate_dataloader,    # 加载模块基于 torch.utils.data.DataLoader\n",
    "        metrics={'acc': Accuracy()},        # 测评方法使用 fastNLP.core.metrics.Accuracy \n",
    "        ...\n",
    "        driver=trainer.driver,              # 保持同 trainer 的 driver 一致\n",
    "        device=None,\n",
    "        ...\n",
    "    )\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c11fe1a",
   "metadata": {},
   "source": [
    "### 1.2  driver 的含义与使用要求\n",
    "\n",
    "在`fastNLP 0.8`中，**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
    "\n",
    "&emsp; 例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
    "\n",
    "在`fastNLP 0.8`中，**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n",
    "\n",
    "&emsp; 具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n",
    "\n",
    "注：这里给出一条建议：**在同一脚本中**，**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n",
    "\n",
    "&emsp; 尽量不出现，之前使用单卡的`driver`，后面又使用多卡的`driver`，这是因为，当脚本执行至\n",
    "\n",
    "&emsp; 多卡`driver`处时，会重启一个进程执行之前所有内容，如此一来可能会造成一些意想不到的麻烦"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cac4a1a",
   "metadata": {},
   "source": [
    "### 1.3 Trainer 内部初始化 Evaluator\n",
    "\n",
    "在`fastNLP 0.8`中，如果在**初始化`Trainer`时**，**传入参数`evaluator_dataloaders`和`metrics`**\n",
    "\n",
    "&emsp; 则在`Trainer`内部，也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
    "\n",
    "```python\n",
    "trainer = Trainer(\n",
    "        model=model,\n",
    "        train_dataloader=train_dataloader,\n",
    "        optimizers=optimizer,\n",
    "        ...\n",
    "        driver=\"torch\",\n",
    "        device='cuda',\n",
    "        ...\n",
    "        evaluate_dataloaders=evaluate_dataloader,   # 传入参数 evaluator_dataloaders\n",
    "        metrics={'acc': Accuracy()},                # 传入参数 metrics\n",
    "        ...\n",
    "    )\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c9c7dda",
   "metadata": {},
   "source": [
    "## 2. argmax 模型的搭建实例"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "524ac200",
   "metadata": {},
   "source": [
    "### 2.1 trainer_step 和 evaluator_step\n",
    "\n",
    "在`fastNLP 0.8`中，使用`pytorch.nn.Module`搭建需要训练的模型，在搭建模型过程中，除了\n",
    "\n",
    "&emsp; 添加`pytorch`要求的`forward`方法外，还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n",
    "\n",
    "```python\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.loss_fn = torch.nn.CrossEntropyLoss()\n",
    "        pass\n",
    "\n",
    "    def forward(self, x):\n",
    "        pass\n",
    "\n",
    "    def train_step(self, x, y):\n",
    "        pred = self(x)\n",
    "        return {\"loss\": self.loss_fn(pred, y)}\n",
    "\n",
    "    def evaluate_step(self, x, y):\n",
    "        pred = self(x)\n",
    "        pred = torch.max(pred, dim=-1)[1]\n",
    "        return {\"pred\": pred, \"target\": y}\n",
    "```\n",
    "***\n",
    "在`fastNLP 0.8`中，**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n",
    "\n",
    "&emsp; 由于，在`Trainer`训练时，**`Trainer`通过参数`train_fn`对应的模型方法获得当前数据批次的损失值**\n",
    "\n",
    "&emsp; 因此，在`Trainer`训练时，`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
    "\n",
    "&emsp; &emsp; 如果没有找到，那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
    "\n",
    "注：在`fastNLP 0.8`中，**`Trainer`要求模型通过`train_step`来返回一个字典**，**满足如`{\"loss\": loss}`的形式**\n",
    "\n",
    "&emsp; 此外，这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换，详见（trainer的详细讲解，待补充）\n",
    "\n",
    "同样，在`fastNLP 0.8`中，**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n",
    "\n",
    "&emsp; 在`Evaluator`测试时，**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n",
    "\n",
    "&emsp; 从用户角度，模型通过`evaluate_step`方法来返回一个字典，内容与传入`Evaluator`的`metrics`一致\n",
    "\n",
    "&emsp; 从模块角度，该字典的键值和`metric`中的`update`函数的签名一致，这样的机制在传参时被称为“**参数匹配**”\n",
    "\n",
    "<img src=\"./figures/T0-fig-training-structure.png\" width=\"68%\" height=\"68%\" align=\"center\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb3272eb",
   "metadata": {},
   "source": [
    "### 2.2 trainer 和 evaluator 的参数匹配\n",
    "\n",
    "在`fastNLP 0.8`中，参数匹配涉及到两个方面，分别是在\n",
    "\n",
    "&emsp; 一方面，**在模型的前向传播中**，**`dataloader`向`train_step`或`evaluate_step`函数传递`batch`**\n",
    "\n",
    "&emsp; 另方面，**在模型的评测过程中**，**`evaluate_dataloader`向`metric`的`update`函数传递`batch`**\n",
    "\n",
    "对于前者，在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
    "\n",
    "&emsp; &emsp; **`fastNLP 0.8`要求`dataloader`生成的每个`batch`**，**满足如`{\"x\": x, \"y\": y}`的形式**\n",
    "\n",
    "&emsp; 同时，`fastNLP 0.8`会查看模型的`train_step`和`evaluate_step`方法的参数签名，并为对应参数传入对应数值\n",
    "\n",
    "&emsp; &emsp; **字典形式的定义**，**对应在`Dataset`定义的`__getitem__`方法中**，例如下方的`ArgMaxDatset`\n",
    "\n",
    "&emsp; 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
    "\n",
    "&emsp; &emsp; `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
    "\n",
    "```python\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, x, y):\n",
    "        self.x = x\n",
    "        self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        return {\"x\": self.x[item], \"y\": self.y[item]}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5f1a6aa",
   "metadata": {},
   "source": [
    "对于后者，首先要明确，在`Trainer`和`Evaluator`中，`metrics`的计算分为`update`和`get_metric`两步\n",
    "\n",
    "&emsp; &emsp; **`update`函数**，**针对一个`batch`的预测结果**，计算其累计的评价指标\n",
    "\n",
    "&emsp; &emsp; **`get_metric`函数**，**统计`update`函数累计的评价指标**，来计算最终的评价结果\n",
    "\n",
    "&emsp; 例如对于`Accuracy`来说，`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n",
    "\n",
    "&emsp; &emsp; 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
    "\n",
    "&emsp; 在此基础上，**`fastNLP 0.8`要求`evaluate_dataloader`生成的每个`batch`传递给对应的`metric`**\n",
    "\n",
    "&emsp; &emsp; **以`{\"pred\": y_pred, \"target\": y_true}`的形式**，对应其`update`函数的函数签名\n",
    "\n",
    "<img src=\"./figures/T0-fig-parameter-matching.png\" width=\"75%\" height=\"75%\" align=\"center\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f62b7bb1",
   "metadata": {},
   "source": [
    "### 2.3 示例：argmax 模型的搭建\n",
    "\n",
    "下文将通过训练`argmax`模型，简单介绍如何`Trainer`模块的使用方式\n",
    "\n",
    "&emsp; 首先，使用`pytorch.nn.Module`定义`argmax`模型，目标是输入一组固定维度的向量，输出其中数值最大的数的索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5314482b",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class ArgMaxModel(nn.Module):\n",
    "    def __init__(self, num_labels, feature_dimension):\n",
    "        nn.Module.__init__(self)\n",
    "        self.num_labels = num_labels\n",
    "\n",
    "        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
    "        self.ac1 = nn.ReLU()\n",
    "        self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
    "        self.ac2 = nn.ReLU()\n",
    "        self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
    "        self.loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        pred = self.ac1(self.linear1(x))\n",
    "        pred = self.ac2(self.linear2(pred))\n",
    "        pred = self.output(pred)\n",
    "        return pred\n",
    "\n",
    "    def train_step(self, x, y):\n",
    "        pred = self(x)\n",
    "        return {\"loss\": self.loss_fn(pred, y)}\n",
    "\n",
    "    def evaluate_step(self, x, y):\n",
    "        pred = self(x)\n",
    "        pred = torch.max(pred, dim=-1)[1]\n",
    "        return {\"pred\": pred, \"target\": y}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71f3fa6b",
   "metadata": {},
   "source": [
    "&emsp; 接着，使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
    "\n",
    "&emsp; &emsp; 数据集包含三个参数：维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
    "\n",
    "&emsp; &emsp; 数据及初始化是，自动生成指定维度的向量，并为每个向量标注出其中最大值的索引作为预测标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe612e61",
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class ArgMaxDataset(Dataset):\n",
    "    def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
    "        self.num_labels = feature_dimension\n",
    "        self.feature_dimension = feature_dimension\n",
    "        self.data_num = data_num\n",
    "        self.seed = seed\n",
    "\n",
    "        g = torch.Generator()\n",
    "        g.manual_seed(1000)\n",
    "        self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
    "        self.y = torch.max(self.x, dim=-1)[1]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data_num\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        return {\"x\": self.x[item], \"y\": self.y[item]}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cb96332",
   "metadata": {},
   "source": [
    "&emsp; 然后，根据`ArgMaxModel`类初始化模型实例，保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
    "\n",
    "&emsp; &emsp; 再根据`ArgMaxDataset`类初始化两个数据集实例，分别用来模型测试和模型评测，数据量各1000笔"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "76172ef8",
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
    "\n",
    "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
    "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e7d25ee",
   "metadata": {},
   "source": [
    "&emsp; 此外，使用`torch.utils.data.DataLoader`初始化两个数据加载模块，批量大小同为8，分别用于训练和测评"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "363b5b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
    "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8d4443f",
   "metadata": {},
   "source": [
    "&emsp; 最后，使用`torch.optim.SGD`初始化一个优化模块，基于随机梯度下降法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dc28a2d9",
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "from torch.optim import SGD\n",
    "\n",
    "optimizer = SGD(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb8ca6cf",
   "metadata": {},
   "source": [
    "## 3. 使用 fastNLP 0.8 训练 argmax 模型\n",
    "\n",
    "### 3.1 trainer 外部初始化的 evaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55145553",
   "metadata": {},
   "source": [
    "通过从`fastNLP`库中导入`Trainer`类，初始化`trainer`实例，对模型进行训练\n",
    "\n",
    "&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
    "\n",
    "&emsp; 通过`progress_bar`设定进度条格式，默认为`\"auto\"`，此外还有`\"rich\"`、`\"raw\"`和`None`\n",
    "\n",
    "&emsp; &emsp; 但对于`\"auto\"`和`\"rich\"`格式，在`jupyter`中，进度条会在训练结束后会被丢弃\n",
    "\n",
    "&emsp; 通过`n_epochs`设定优化迭代轮数，默认为20；全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b51b7a2d",
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver=\"torch\",\n",
    "    device='cuda',\n",
    "    train_dataloader=train_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    n_epochs=10,                    # 设定迭代轮数 \n",
    "    progress_bar=\"auto\"             # 设定进度条格式\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e202d6e",
   "metadata": {},
   "source": [
    "通过使用`Trainer`类的`run`函数，进行训练\n",
    "\n",
    "&emsp; 其中，可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止，默认全部\n",
    "\n",
    "&emsp; `run`函数完成后在`jupyter`中没有输出保留，此外，通过`help(trainer.run)`可以查询`run`函数的详细内容"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ba047ead",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c16c5fa4",
   "metadata": {},
   "source": [
    "通过从`fastNLP`库中导入`Evaluator`类，初始化`evaluator`实例，对模型进行评测\n",
    "\n",
    "&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
    "\n",
    "&emsp; 需要注意的是评测方法`metrics`，设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
    "\n",
    "&emsp; 类似地，也可以通过`progress_bar`限定进度条格式，默认为`\"auto\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1c6b6b36",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from fastNLP import Evaluator\n",
    "from fastNLP import Accuracy\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    model=model,\n",
    "    driver=trainer.driver,          # 需要使用 trainer 已经启动的 driver\n",
    "    device=None,\n",
    "    dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()}     # 需要严格使用此种形式的字典\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8157bb9b",
   "metadata": {},
   "source": [
    "通过使用`Evaluator`类的`run`函数，进行训练\n",
    "\n",
    "&emsp; 其中，可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止，默认全部\n",
    "\n",
    "&emsp; 最终，输出形如`{'acc#acc': acc}`的字典，在`jupyter`中，进度条会在评测结束后会被丢弃"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f7cb0165",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span><span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd9f68fa",
   "metadata": {},
   "source": [
    "### 3.2 trainer 内部初始化的 evaluator \n",
    "\n",
    "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`，可以实现在训练过程中进行评测\n",
    "\n",
    "&emsp; 通过`progress_bar`同时设定训练和评估进度条格式，在`jupyter`中，在进度条训练结束后会被丢弃\n",
    "\n",
    "&emsp; 但是中间的评估结果仍会保留；**通过`evaluate_every`设定评估频率**，可以为负数、正数或者函数：\n",
    "\n",
    "&emsp; &emsp; **为负数时**，**表示每隔几个`epoch`评估一次**；**为正数时**，**则表示每隔几个`batch`评估一次**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "183c7d19",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver=trainer.driver,      # 因为是在同个脚本中，这里的 driver 同样需要重用\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()},\n",
    "    optimizers=optimizer,\n",
    "    n_epochs=10,  \n",
    "    evaluate_every=-1,          # 表示每个 epoch 的结束进行评估\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "714cc404",
   "metadata": {},
   "source": [
    "通过使用`Trainer`类的`run`函数，进行训练\n",
    "\n",
    "&emsp; 还可以通过**参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测**，**默认为`2`**\n",
    "\n",
    "&emsp; 之所以“先评测后训练”，是为了保证训练很长时间的数据，不会在评测阶段出问题，故作此**试探性评测**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2e4daa2c",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[18:28:25] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches.              <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.33</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">33.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">34.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.37</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40.0</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n",
       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c4e9c619",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluator.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
